"""
图片提供者抽象基类
"""

from abc import ABC, abstractmethod
from typing import List, Optional, Dict, Any, AsyncGenerator
import logging
import asyncio
from pathlib import Path

from ..models import (
    ImageInfo, ImageSearchRequest, ImageGenerationRequest, 
    ImageUploadRequest, ImageSearchResult, ImageOperationResult,
    ImageProvider, ImageSourceType
)

logger = logging.getLogger(__name__)


class BaseImageProvider(ABC):
    """图片提供者抽象基类"""
    
    def __init__(self, provider: ImageProvider, config: Dict[str, Any]):
        self.provider = provider
        self.config = config
        self.enabled = config.get('enabled', True)
        self.rate_limit = config.get('rate_limit', 60)  # 每分钟请求限制
        self.timeout = config.get('timeout', 30)  # 请求超时时间
        
        # 请求计数器（简单的速率限制）
        self._request_count = 0
        self._last_reset = asyncio.get_event_loop().time()
    
    @property
    @abstractmethod
    def source_type(self) -> ImageSourceType:
        """返回图片来源类型"""
        pass
    
    @property
    @abstractmethod
    def supported_operations(self) -> List[str]:
        """返回支持的操作列表"""
        pass
    
    async def _check_rate_limit(self) -> bool:
        """检查速率限制"""
        current_time = asyncio.get_event_loop().time()
        
        # 每分钟重置计数器
        if current_time - self._last_reset > 60:
            self._request_count = 0
            self._last_reset = current_time
        
        if self._request_count >= self.rate_limit:
            logger.warning(f"Rate limit exceeded for {self.provider}")
            return False
        
        self._request_count += 1
        return True
    
    async def _validate_config(self) -> bool:
        """验证配置"""
        return True
    
    async def health_check(self) -> Dict[str, Any]:
        """健康检查"""
        try:
            config_valid = await self._validate_config()
            rate_limit_ok = await self._check_rate_limit()
            
            return {
                'provider': self.provider,
                'enabled': self.enabled,
                'config_valid': config_valid,
                'rate_limit_ok': rate_limit_ok,
                'status': 'healthy' if (self.enabled and config_valid) else 'unhealthy'
            }
        except Exception as e:
            logger.error(f"Health check failed for {self.provider}: {e}")
            return {
                'provider': self.provider,
                'enabled': False,
                'status': 'error',
                'error': str(e)
            }


class ImageSearchProvider(BaseImageProvider):
    """图片搜索提供者基类"""
    
    @property
    def source_type(self) -> ImageSourceType:
        return ImageSourceType.WEB_SEARCH
    
    @property
    def supported_operations(self) -> List[str]:
        return ['search', 'get_details', 'download']
    
    @abstractmethod
    async def search(self, request: ImageSearchRequest) -> ImageSearchResult:
        """搜索图片"""
        pass
    
    @abstractmethod
    async def get_image_details(self, image_id: str) -> Optional[ImageInfo]:
        """获取图片详细信息"""
        pass
    
    @abstractmethod
    async def download_image(self, image_info: ImageInfo, save_path: Path) -> ImageOperationResult:
        """下载图片到本地"""
        pass
    
    async def batch_search(self, requests: List[ImageSearchRequest]) -> List[ImageSearchResult]:
        """批量搜索图片"""
        results = []
        for request in requests:
            try:
                if not await self._check_rate_limit():
                    break
                result = await self.search(request)
                results.append(result)
                # 添加延迟避免过快请求
                await asyncio.sleep(0.1)
            except Exception as e:
                logger.error(f"Batch search failed for request {request.query}: {e}")
                # 创建空结果
                results.append(ImageSearchResult(
                    images=[], total_count=0, page=request.page, 
                    per_page=request.per_page, has_next=False, has_prev=False,
                    search_time=0.0
                ))
        return results


class ImageGenerationProvider(BaseImageProvider):
    """AI图片生成提供者基类"""
    
    @property
    def source_type(self) -> ImageSourceType:
        return ImageSourceType.AI_GENERATED
    
    @property
    def supported_operations(self) -> List[str]:
        return ['generate', 'get_models', 'get_styles']
    
    @abstractmethod
    async def generate(self, request: ImageGenerationRequest) -> ImageOperationResult:
        """生成图片"""
        pass
    
    @abstractmethod
    async def get_available_models(self) -> List[Dict[str, Any]]:
        """获取可用模型列表"""
        pass
    
    async def get_available_styles(self) -> List[Dict[str, Any]]:
        """获取可用样式列表"""
        return []
    
    async def batch_generate(self, requests: List[ImageGenerationRequest]) -> List[ImageOperationResult]:
        """批量生成图片"""
        results = []
        for request in requests:
            try:
                if not await self._check_rate_limit():
                    break
                result = await self.generate(request)
                results.append(result)
                # AI生成通常需要更长时间，添加更长延迟
                await asyncio.sleep(1.0)
            except Exception as e:
                logger.error(f"Batch generation failed for prompt {request.prompt}: {e}")
                results.append(ImageOperationResult(
                    success=False,
                    message=f"Generation failed: {str(e)}",
                    error_code="generation_error"
                ))
        return results


class LocalStorageProvider(BaseImageProvider):
    """本地存储提供者基类"""
    
    @property
    def source_type(self) -> ImageSourceType:
        return ImageSourceType.LOCAL_STORAGE
    
    @property
    def supported_operations(self) -> List[str]:
        return ['upload', 'list', 'get', 'delete', 'update']
    
    @abstractmethod
    async def upload(self, request: ImageUploadRequest, file_data: bytes) -> ImageOperationResult:
        """上传图片"""
        pass
    
    @abstractmethod
    async def list_images(self, 
                         page: int = 1, 
                         per_page: int = 20,
                         category: Optional[str] = None,
                         tags: Optional[List[str]] = None) -> ImageSearchResult:
        """列出图片"""
        pass
    
    @abstractmethod
    async def get_image(self, image_id: str) -> Optional[ImageInfo]:
        """获取图片信息"""
        pass
    
    @abstractmethod
    async def delete_image(self, image_id: str) -> ImageOperationResult:
        """删除图片"""
        pass
    
    @abstractmethod
    async def update_image(self, image_id: str, updates: Dict[str, Any]) -> ImageOperationResult:
        """更新图片信息"""
        pass
    
    async def search_local(self, query: str, **kwargs) -> ImageSearchResult:
        """在本地图片中搜索"""
        # 默认实现：通过标签和关键词搜索
        tags = kwargs.get('tags', [])
        if query:
            tags.append(query)
        
        return await self.list_images(
            page=kwargs.get('page', 1),
            per_page=kwargs.get('per_page', 20),
            tags=tags
        )


class ProviderRegistry:
    """提供者注册表"""
    
    def __init__(self):
        self._providers: Dict[ImageProvider, BaseImageProvider] = {}
        self._search_providers: List[ImageSearchProvider] = []
        self._generation_providers: List[ImageGenerationProvider] = []
        self._storage_providers: List[LocalStorageProvider] = []
    
    def register(self, provider: BaseImageProvider):
        """注册提供者"""
        # 检查是否已经注册了相同的提供者
        if provider.provider in self._providers:
            logger.debug(f"Provider {provider.provider} already registered, skipping")
            return

        self._providers[provider.provider] = provider

        if isinstance(provider, ImageSearchProvider):
            # 检查是否已经存在相同类型的搜索提供者
            existing = [p for p in self._search_providers if p.provider == provider.provider]
            if not existing:
                self._search_providers.append(provider)
        elif isinstance(provider, ImageGenerationProvider):
            # 检查是否已经存在相同类型的生成提供者
            existing = [p for p in self._generation_providers if p.provider == provider.provider]
            if not existing:
                self._generation_providers.append(provider)
        elif isinstance(provider, LocalStorageProvider):
            # 检查是否已经存在相同类型的存储提供者
            existing = [p for p in self._storage_providers if p.provider == provider.provider]
            if not existing:
                self._storage_providers.append(provider)
    
    def get_provider(self, provider: ImageProvider) -> Optional[BaseImageProvider]:
        """获取指定提供者"""
        return self._providers.get(provider)
    
    def get_search_providers(self, enabled_only: bool = True) -> List[ImageSearchProvider]:
        """获取搜索提供者"""
        if enabled_only:
            return [p for p in self._search_providers if p.enabled]
        return self._search_providers
    
    def get_generation_providers(self, enabled_only: bool = True) -> List[ImageGenerationProvider]:
        """获取生成提供者"""
        if enabled_only:
            return [p for p in self._generation_providers if p.enabled]
        return self._generation_providers
    
    def get_storage_providers(self, enabled_only: bool = True) -> List[LocalStorageProvider]:
        """获取存储提供者"""
        if enabled_only:
            return [p for p in self._storage_providers if p.enabled]
        return self._storage_providers
    
    async def health_check_all(self) -> Dict[str, Any]:
        """检查所有提供者健康状态"""
        results = {}
        for provider_name, provider in self._providers.items():
            results[provider_name] = await provider.health_check()
        return results


# 全局提供者注册表
provider_registry = ProviderRegistry()
